# wine_znkd.py
#
# Full ZNKD pipeline for UCI Wine-Quality (red + white):
#  - Teacher: 10-qubit VQC trained on true quality labels
#  - ZNE:    zero-noise-extrapolated energies per class, from the teacher
#  - Targets: tanh-stabilized regression labels from ZNE-corrected energies
#  - Student: 6-qubit QNN trained via regression on those soft targets

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit.utils import QuantumInstance
from qiskit.circuit.library import EfficientSU2
from qiskit import QuantumCircuit

from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC, NeuralNetworkRegressor

from wine_zne_utils import zne_expectation_zero

# Seed shared by PCA, the train/test split, the Aer simulators, transpilation
# and the ZNE helper, so repeated runs are reproducible.
SEED: int = 123
TEACHER_QUBITS: int = 10
STUDENT_QUBITS: int = 6
# One PCA component per teacher qubit: each component drives one Ry angle.
PCA_COMPONENTS: int = TEACHER_QUBITS


# =========================================================
# 1. Load Wine-Quality, preprocess, PCA, angle-encode
# =========================================================
def load_wine_pca():
    """Load red + white Wine-Quality data, split it, then scale, PCA and angle-encode.

    The stratified 80/20 split happens *before* any fitting. StandardScaler,
    PCA and the [0, π] angle range are learned from the training rows only and
    then applied unchanged to the test rows, so no test-set statistics leak
    into the features the models are trained on.

    Returns:
        (X_train, X_test, y_train, y_test): feature arrays with
        ``PCA_COMPONENTS`` columns whose values lie in [0, π] (test values are
        clipped to that range), and integer quality labels.
    """
    print("⇨ Downloading UCI wine-quality CSVs …")
    base_url = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/"
        "wine-quality/"
    )
    red = pd.read_csv(base_url + "winequality-red.csv", sep=";")
    white = pd.read_csv(base_url + "winequality-white.csv", sep=";")
    data = pd.concat([red, white], ignore_index=True)

    X = data.drop(columns=["quality"]).values.astype(np.float32)
    # Integer quality scores; the combined red + white data spans 3–9.
    y = data["quality"].values.astype(int)

    # Split first: the row assignment depends only on y and the seed, so this
    # selects the same rows as splitting after preprocessing would.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=SEED
    )

    # Standardize and reduce to one component per qubit, fitted on train only.
    scaler = StandardScaler().fit(X_train)
    pca = PCA(n_components=PCA_COMPONENTS, random_state=SEED)
    Z_train = pca.fit_transform(scaler.transform(X_train))
    Z_test = pca.transform(scaler.transform(X_test))

    # Map to [0, π] for angle encoding with a single global range taken from
    # the training set.
    lo, hi = Z_train.min(), Z_train.max()
    scale = np.pi / (hi - lo + 1e-12)
    X_train = (Z_train - lo) * scale
    # Test points can fall outside the training range; clip so every angle
    # stays inside the encoding interval.
    X_test = np.clip((Z_test - lo) * scale, 0.0, np.pi)

    print(f"   • Train size: {X_train.shape[0]}")
    print(f"   • Test  size: {X_test.shape[0]}")
    return X_train, X_test, y_train, y_test


def make_feature_map(num_qubits):
    """Return a builder that Ry-encodes one feature vector into a fresh circuit.

    The returned callable takes a sequence of angles ``x`` and returns a new
    ``QuantumCircuit`` on ``num_qubits`` qubits with ``ry(float(x[i]))`` on
    qubit ``i`` for each of the first ``num_qubits`` entries. Extra entries are
    ignored; a shorter ``x`` leaves the remaining qubits without a gate.
    """
    def encode(x):
        circuit = QuantumCircuit(num_qubits)
        leading_angles = x[:num_qubits]
        for qubit, angle in zip(range(num_qubits), leading_angles):
            circuit.ry(float(angle), qubit)
        return circuit

    return encode


# =========================================================
# 2. Teacher VQC
# =========================================================
def build_teacher_vqc(n_classes):
    """Build (but do not train) the TEACHER_QUBITS-qubit teacher VQC.

    Feature i is Ry angle-encoded on qubit i ahead of an
    ``EfficientSU2(TEACHER_QUBITS, reps=2)`` ansatz. Circuits run on a seeded
    AerSimulator, and training uses COBYLA.

    Args:
        n_classes: Number of quality classes. Kept for API compatibility only.
            ``VQC`` has no ``num_classes`` argument; it derives the classes
            from the labels passed to ``fit``.

    Returns:
        An untrained ``VQC``.
    """
    from qiskit.circuit import ParameterVector

    try:
        from qiskit_algorithms.optimizers import COBYLA
    except ImportError:  # older Qiskit ships the optimizers in qiskit.algorithms
        from qiskit.algorithms.optimizers import COBYLA

    # VQC needs a *parametrised* feature-map circuit; it binds each sample's
    # features into these parameters itself. make_feature_map returns a
    # Python callable that builds already-bound circuits, which VQC cannot use.
    features = ParameterVector("x", TEACHER_QUBITS)
    feature_map = QuantumCircuit(TEACHER_QUBITS)
    for qubit, angle in enumerate(features):
        feature_map.ry(angle, qubit)

    ansatz = EfficientSU2(TEACHER_QUBITS, reps=2)
    backend = AerSimulator(seed_simulator=SEED)
    # NOTE(review): ``quantum_instance`` was removed from VQC in
    # qiskit-machine-learning 0.7; on newer versions pass ``sampler=...``.
    qinst = QuantumInstance(
        backend=backend,
        seed_simulator=SEED,
        seed_transpiler=SEED,
    )

    return VQC(
        feature_map=feature_map,
        ansatz=ansatz,
        # The trainer calls ``optimizer.minimize``; a bare string has no such
        # method, so an optimizer instance is required.
        optimizer=COBYLA(),
        quantum_instance=qinst,
    )


# =========================================================
# 3. Student QNN regressor (ZNKD student)
# =========================================================
def build_student_regressor():
    """Build (but do not train) the STUDENT_QUBITS-qubit ZNKD student regressor.

    The circuit is an ``EfficientSU2(STUDENT_QUBITS, reps=1)`` ansatz. Its
    first ``STUDENT_QUBITS`` parameters (in the circuit's sorted parameter
    order) are bound to the input features, and all remaining parameters are
    trained.

    Without a custom interpretation, the network emits 2**STUDENT_QUBITS
    bitstring probabilities per sample. Callers whose targets have a different
    width should call ``regressor.neural_network.set_interpret(...)`` before
    ``fit``.

    Returns:
        An untrained ``NeuralNetworkRegressor`` with squared-error ("l2") loss
        and a COBYLA optimizer.
    """
    try:
        from qiskit_algorithms.optimizers import COBYLA
    except ImportError:  # older Qiskit ships the optimizers in qiskit.algorithms
        from qiskit.algorithms.optimizers import COBYLA

    ansatz = EfficientSU2(STUDENT_QUBITS, reps=1)
    params = list(ansatz.parameters)

    # Execution is configured on the QNN, which uses a default Sampler
    # primitive when none is given. NeuralNetworkRegressor accepts no
    # quantum_instance argument.
    qnn = SamplerQNN(
        circuit=ansatz,
        input_params=params[:STUDENT_QUBITS],
        weight_params=params[STUDENT_QUBITS:],
        sparse=False,
    )

    # The trainer calls ``optimizer.minimize``; pass an optimizer instance,
    # not its name.
    return NeuralNetworkRegressor(
        neural_network=qnn,
        loss="l2",
        optimizer=COBYLA(),
    )


# =========================================================
# 4. ZNE-based tanh targets
# =========================================================
def compute_zne_tanh_targets(vqc, X, base_eps=0.01, tau=1.0):
    """Compute tanh-squashed zero-noise-extrapolated energies for every sample.

    For each row ``x`` of ``X``, this calls
    ``zne_expectation_zero(vqc, x, base_eps=base_eps, seed=SEED)`` and maps the
    returned energies through ``tanh(energies / tau)``. This bounds each target
    to (-1, 1); a larger ``tau`` keeps more of the energies' linear range.

    Args:
        vqc: Trained teacher, forwarded to ``zne_expectation_zero``.
        X: Sequence or array of encoded samples.
        base_eps: Base noise strength forwarded to ``zne_expectation_zero``.
        tau: Positive temperature dividing the energies before ``tanh``.

    Returns:
        ``np.ndarray`` stacking one target row per sample (row i <-> X[i]).

    Raises:
        ValueError: If ``tau`` is not positive, or ``X`` is empty.
    """
    # tau <= 0 would divide by zero or flip every sign, which silently inverts
    # the "lowest energy wins" rule used downstream. ``not tau > 0`` also
    # rejects NaN.
    if not tau > 0:
        raise ValueError(f"tau must be positive, got {tau!r}")
    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    all_targets = []
    for i, x in enumerate(X, start=1):
        if i % 100 == 0:
            print(f"[ZNE targets] {i}/{n_samples} samples", flush=True)

        # asarray lets the helper return a list/tuple as well as an ndarray.
        energies = np.asarray(
            zne_expectation_zero(vqc, x, base_eps=base_eps, seed=SEED)
        )
        all_targets.append(np.tanh(energies / tau))

    return np.stack(all_targets, axis=0)


# =========================================================
# 5. Accuracy helpers
# =========================================================
def teacher_accuracy(vqc, X, y):
    """Return the share of samples whose predicted label matches ``y``.

    Args:
        vqc: Fitted classifier exposing ``predict(X)``.
        X: Samples to classify.
        y: True labels, in the same label space that ``predict`` returns.

    Returns:
        Mean of the element-wise equality between predictions and ``y``.
    """
    correct = vqc.predict(X) == y
    return correct.mean()


def student_accuracy(student_regressor, X, y, classes=None):
    """Score a distilled student by taking the lowest-energy output per sample.

    Args:
        student_regressor: Fitted model whose ``predict(X)`` returns a 2-D
            array of shape (n_samples, n_outputs).
        X: Inputs to score.
        y: True labels. Without ``classes``, these must be 0-based output
            column indices (0 .. n_outputs-1).
        classes: Optional sequence mapping output column index -> label, e.g.
            the sorted unique training labels. Pass it when ``y`` holds raw
            labels (such as wine-quality scores) rather than column indices;
            otherwise index 0 would be compared against label 3, and so on.

    Returns:
        Fraction of samples whose predicted class equals ``y``.
    """
    preds = np.asarray(student_regressor.predict(X))
    # Lower energy ⇒ higher confidence ⇒ argmin over the output axis.
    idx = np.argmin(preds, axis=1)
    y_hat = idx if classes is None else np.asarray(classes)[idx]
    return (y_hat == np.asarray(y)).mean()


# =========================================================
# 6. Main ZNKD pipeline
# =========================================================
def main():
    """Run the full ZNKD pipeline and print teacher and student test accuracy.

    Steps:
      1. Load and encode the data.
      2. Train the teacher VQC on the true quality labels.
      3. Distil ZNE-based tanh targets from the teacher on the training set.
      4. Train the student regressor on those targets.
      5. Score both models on the held-out test set.
    """
    X_train, X_test, y_train, y_test = load_wine_pca()

    # The student is scored by argmin over per-class outputs, which yields a
    # 0-based column index, not a raw quality score. Encode the test labels
    # into the same index space. np.unique returns the labels sorted; this
    # assumes the per-class ZNE energies use the same order (TODO confirm in
    # wine_zne_utils). load_wine_pca splits with stratify=y, so every test
    # label also occurs in training and searchsorted is exact.
    classes = np.unique(y_train)
    y_test_idx = np.searchsorted(classes, y_test)
    n_classes = len(classes)

    print("⇨ Building teacher VQC …")
    teacher = build_teacher_vqc(n_classes=n_classes)

    print("⇨ Training teacher on true labels …")
    teacher.fit(X_train, y_train)

    print("\n=== Teacher accuracy (noiseless simulator) ===")
    acc_teacher = teacher_accuracy(teacher, X_test, y_test)
    print(f"Teacher: {acc_teacher:.3f}")

    print("\n⇨ Computing ZNE-based tanh targets for distillation …")
    zne_targets = compute_zne_tanh_targets(
        teacher, X_train, base_eps=0.01, tau=1.0
    )

    print("⇨ Building student QNN regressor …")
    student = build_student_regressor()
    # By default a SamplerQNN emits 2**STUDENT_QUBITS bitstring probabilities.
    # Fold bitstrings onto one output per target column so the network's
    # output width matches the target matrix (assumed (n_train, n_classes)).
    n_outputs = zne_targets.shape[1]
    student.neural_network.set_interpret(
        lambda bitstring: bitstring % n_outputs, n_outputs
    )

    # The student exposes only STUDENT_QUBITS input parameters, so feed it the
    # leading (highest-variance) PCA components.
    X_train_student = X_train[:, :STUDENT_QUBITS]
    X_test_student = X_test[:, :STUDENT_QUBITS]

    print("⇨ Training student on ZNE-tanh targets …")
    student.fit(X_train_student, zne_targets)

    print("\n=== Distilled student accuracy ===")
    acc_student = student_accuracy(student, X_test_student, y_test_idx)
    print(f"Student (ZNKD regression): {acc_student:.3f}")


if __name__ == "__main__":
    main()
